# -*- coding: utf-8 -*-
"""
Created on Thu Nov  7 09:32:47 2024

@author: aharding6
"""


import matplotlib.pyplot as plt
from scipy.optimize import minimize
import numpy as np



# =========================================== Parameters =========================================== #
tstep = 1                                   #Years per period
T = int(127/tstep)                          #Modeling time horizon (number of periods; 127 at tstep = 1)


## Preferences  ##
elasmu = 1.45                               #Elasticity of marginal utility of consumption
prstp = 0.015                               #Initial rate of social time preference per year


## == Population and technology == ##
gama = 0.33                                 #Capital elasticity in production function
pop0 = 6838.00                              #Initial world population (millions)
# NOTE(review): ga0 and dela below are divided by 5 to turn their 5-year figures into
# per-year rates, but 0.134 is only multiplied by tstep. If 0.134 is also a per-5-year
# rate this should probably be 0.134*tstep/5 -- confirm the calibration.
popadj = 0.134*tstep                        #Population convergence rate per period (calibrated to 2050 pop projection)
popasym = 10500.00                          #Asymptotic population (millions)
dk = 0.08                                   #Depreciation rate on capital (per year)
q0 = 63.69                                  #Initial world gross output (trill 2005 USD)
k0 = 135.00                                 #Initial capital value (trill 2005 USD)
a0 = 3.80                                   #Initial level of total factor productivity
# 0.079 and 0.006 are 5-year figures; dividing by 5 makes ga0 and dela approximately per-year.
# NOTE(review): the TFP recursion additionally raises (1 - ga) to the power tstep/5, which
# converts from 5-year to annual a second time -- confirm which conversion is intended.
ga0 = 0.0790/5                              #Initial TFP growth rate (per year: 0.079 per 5 years / 5)
dela = 0.0060/5                             #Decline rate of TFP growth (per year: 0.006 / 5)
optlrsav = (dk + .004)/(dk + .004*elasmu + prstp)*gama       #Optimal long-run savings rate used for transversality
ActS0 = optlrsav                            #Savings rate imposed in period 0 and in the final 10 periods

## == Climate model parameters == ##
tatm0 = 0                                   #Initial atmospheric temp change (C from 2024)


# ============================================ State Variables ============================================ #

# == Capital ($trill, 2005$) == #
# Capital path: entry 0 is the initial stock k0; the welfare functions overwrite
# entries 1..T-1 each time they simulate a savings path.
K_DICE = T * [k0]

# Prescribed atmospheric temperature path, built from three linear segments:
# 0 -> 1 (step 1/26), 1 -> 1.75 (step 0.75/25), 1.75 -> 2.15 (step 0.4/76).
# The segments are expected to total T (= 127) entries, because the plots draw
# this list against the year axis.
Tat_DICE = [
    *np.arange(0, 1, 1/26),
    *np.arange(1, 1.75, 0.75/25),
    *np.arange(1.75, 2.15, 0.4/76),
]




# ============================================ Exogenous Variables ============================================ #
# Every exogenous path has one entry per period 0..T-1.

# == Technology == #
# TFP growth rate, decaying exponentially with elapsed years (tstep * t).
ga = [ga0 * np.exp(-dela * tstep * t) for t in range(T)]

al = [0] * T                                # Level of total factor productivity
l = [0] * T                                 # Level of population and labor (millions)
rr = [1] * T                                # Utility discount factor; rr[0] stays 1

for i in range(T):
    if i == 0:
        # Initial conditions.
        al[i] = a0
        l[i] = pop0
        continue
    # Discount factor compounds at the pure rate of time preference.
    rr[i] = rr[i - 1] / (1 + prstp)**tstep
    # TFP grows by a factor 1 / (1 - ga)^(tstep/5) per period.
    # NOTE(review): ga0 is already divided by 5; confirm the tstep/5 exponent is intended too.
    al[i] = al[i - 1] / (1 - ga[i - 1])**(tstep / 5)
    # Population moves geometrically toward its asymptote popasym.
    l[i] = l[i - 1] * (popasym / l[i - 1])**popadj


# ============================================ Output Variables ============================================ #
# Per-period result buffers of length T. The fDICE* functions overwrite them in
# place on every call, and fDICE2/fDICE2CC return several of them by reference.
Yt_DICE = T * [q0]   # Output Y1 from state()/stateCC() ($trill), i.e. after the climate TFP effect; starts at q0
Qt_DICE = T * [0]    # Net output Q1 ($trill); equals Y1 in the current state functions
Dm_DICE = T * [0]    # Dm1 = 1 + impact/100 from state()/stateCC(): the TFP multiplier A2/A1
Dt_DICE = T * [0]    # Climate damages D1 = climate-free output minus Y1 (trillion $)
Sv_DICE = T * [0]    # Saving ($trill, 2005$)
Cn_DICE = T * [0]    # Consumption ($trill per year)
cn_DICE = T * [0]    # Consumption per capita ($thous per year)
Ut_DICE = T * [0]    # Discounted, population-weighted period utility
ut_DICE = T * [0]    # Utility of per-capita consumption
Cp_DICE = T * [0]    # Carbon price (2005$ per ton of CO2); not written by the code in this file





def state( ST1, EX1, Act1 ):
    """Advance the economy one period with climate impacts switched OFF.

    Same accounting as ``stateCC``, but every distributed-lag temperature
    weight is zero. As a result A2 equals A1 and the damages D1 are zero.

    Parameters
    ----------
    ST1 : [K1, Tatx1]
        Capital stock this period and a window of past temperatures. The
        fDICE* callers pass the 10 previous values, oldest first.
    EX1 : [A1, L1, df1]
        Exogenous TFP, population (millions) and utility discount factor.
    Act1 : [xs1]
        Savings rate, as a fraction of net output Q1.

    Returns
    -------
    (ST2, LEV2, STGE1)
        ST2   = [K2, Tatx1]: next-period capital; the temperature window is passed through.
        LEV2  = [Y1, Q1, Dm1, D1, S1, Con1, con1, U1, u1]: this period's levels.
        STGE1 = [A2]: TFP after the temperature impact.
    """
    K1, Tatx1 = ST1
    A1, L1, df1 = EX1
    (xs1,) = Act1

    labour_term = (L1/1000)**(1 - gama)
    capital_term = K1**gama

    # Climate-free gross output.
    Y2 = A1 * labour_term * capital_term

    # Distributed-lag TFP impact (in percent); all weights are zero here.
    lag_weights = [0] * 10
    impact = sum(np.array(lag_weights[::-1]) * np.array(Tatx1)) / 100
    A2 = A1 * (1 + impact)
    # The "damage" slot carries the TFP multiplier A2 / A1.
    Dm1 = 1 + impact

    Y1 = A2 * labour_term * capital_term
    D1 = Y2 - Y1
    Q1 = Y1

    S1 = xs1 * Q1
    Con1 = Q1 - S1
    # Trillions / millions * 1000 -> thousands of $ per person.
    con1 = (Con1/L1) * 1000
    u1 = (con1**(1 - elasmu) - 1) / (1 - elasmu) - 1
    U1 = u1 * L1 * df1

    K2 = (1 - dk)**tstep * K1 + tstep * S1

    return ([K2, Tatx1], [Y1, Q1, Dm1, D1, S1, Con1, con1, U1, u1], [A2])

def stateCC( ST1, EX1, Act1 ):
    """Advance the economy one period with temperature affecting TFP.

    TFP this period is A1 scaled by 1 + impact/100, where the impact is a
    10-lag weighted sum of past temperatures. Each weight
    122.747*(exp(-0.3291628*s) - exp(-0.3018996*s)) is negative for s >= 1,
    so positive past temperatures lower A2 below A1.

    Parameters
    ----------
    ST1 : [K1, Tatx1]
        Capital stock this period and a window of past temperatures. The
        fDICE* callers pass the 10 previous values, oldest first, so the
        lag-1 weight multiplies the last element.
    EX1 : [A1, L1, df1]
        Exogenous TFP, population (millions) and utility discount factor.
    Act1 : [xs1]
        Savings rate, as a fraction of net output Q1.

    Returns
    -------
    (ST2, LEV2, STGE1)
        ST2   = [K2, Tatx1]: next-period capital; the temperature window is passed through.
        LEV2  = [Y1, Q1, Dm1, D1, S1, Con1, con1, U1, u1]: this period's levels.
        STGE1 = [A2]: TFP after the temperature impact.
    """
    K1, Tatx1 = ST1
    A1, L1, df1 = EX1
    (xs1,) = Act1

    labour_term = (L1/1000)**(1 - gama)
    capital_term = K1**gama

    # Climate-free gross output.
    Y2 = A1 * labour_term * capital_term

    # Distributed-lag TFP impact (in percent) of the last 10 temperatures.
    lags = range(1, 11)
    lag_weights = [122.747*(np.exp(-0.3291628*s) - np.exp(-0.3018996*s)) for s in lags]
    impact = sum(np.array(lag_weights[::-1]) * np.array(Tatx1)) / 100
    A2 = A1 * (1 + impact)
    # The "damage" slot carries the TFP multiplier A2 / A1.
    Dm1 = 1 + impact

    Y1 = A2 * labour_term * capital_term
    D1 = Y2 - Y1
    Q1 = Y1

    S1 = xs1 * Q1
    Con1 = Q1 - S1
    # Trillions / millions * 1000 -> thousands of $ per person.
    con1 = (Con1/L1) * 1000
    u1 = (con1**(1 - elasmu) - 1) / (1 - elasmu) - 1
    U1 = u1 * L1 * df1

    K2 = (1 - dk)**tstep * K1 + tstep * S1

    return ([K2, Tatx1], [Y1, Q1, Dm1, D1, S1, Con1, con1, U1, u1], [A2])



# =========================================== Welfare Function (DICE model) =========================================== #

def fDICE(v):
    """Optimisation objective without climate impacts: negative discounted welfare.

    Simulates periods 0..T-2 with ``state``. As a side effect it writes the
    results into the module-level buffers (K_DICE[1:], Yt_DICE, ..., ut_DICE).

    Parameters
    ----------
    v : sequence of float
        Candidate savings rates. Only v[1] .. v[T-11] are read; period 0 and
        the last 10 periods use ActS0.

    Returns
    -------
    float
        -W, where W is the sum of Ut_DICE[t] * tstep over t = 0..T-2.
    """
    welfare = 0
    # Ten leading zeros stand in for temperatures before period 0.
    padded_temps = [0] * 10 + Tat_DICE
    for t in range(T - 1):
        # The ten temperatures preceding period t, oldest first.
        window = padded_temps[t:t + 10]
        rate = ActS0 if (t == 0 or t >= T - 10) else v[t]

        next_state, levels, _ = state([K_DICE[t], window], [al[t], l[t], rr[t]], [rate])
        K_DICE[t + 1] = next_state[0]
        (Yt_DICE[t], Qt_DICE[t], Dm_DICE[t], Dt_DICE[t], Sv_DICE[t],
         Cn_DICE[t], cn_DICE[t], Ut_DICE[t], ut_DICE[t]) = levels

        welfare += Ut_DICE[t] * tstep
    return -welfare

def fDICE2(v):
    """Simulate savings path *v* without climate impacts and return the results.

    Uses the same recursion, savings rule and welfare sum as ``fDICE``. It also
    simulates the final period T-1, so every reported series is filled: without
    that step, index T-1 would keep placeholder values (zeros, or q0 for Yt_DICE).
    The final period's utility is NOT added to W, so -W equals ``fDICE(v)``.

    Parameters
    ----------
    v : sequence of float
        Savings rates. Only v[1] .. v[T-11] are read; period 0 and the last 10
        periods use ActS0.

    Returns
    -------
    tuple
        (-W, Qt_DICE, Cn_DICE, Sv_DICE, Dm_DICE, al2, K_DICE, Ut_DICE). All lists
        except al2 are the module-level buffers themselves, not copies, so a later
        fDICE* call overwrites them. al2 is a new list of post-impact TFP (A2).
    """
    W = 0
    al2 = [0]*T                                  #TFP after the temperature impact (A2), per period
    for i in range(T):
        # The ten temperatures preceding period i, oldest first (zero-padded at the start).
        if i < 10:
            Tat_DICE_temp = [0]*(10-i) + Tat_DICE[0:i]
        else:
            Tat_DICE_temp = Tat_DICE[i-10:i]
        STi = [K_DICE[i], Tat_DICE_temp]
        EXi = [al[i], l[i], rr[i]]
        # Period 0 and the last 10 periods use the fixed savings rate ActS0.
        if i >= T-10 or i == 0:
            Acti = [ActS0]
        else:
            Acti = [v[i]]

        ( STii, LEVii, STGEii ) = state( STi, EXi, Acti )
        # The last period has no successor capital slot.
        if i + 1 < T:
            K_DICE[i + 1] = STii[0]
        al2[i] = STGEii[0]
        [Yt_DICE[i], Qt_DICE[i], Dm_DICE[i], Dt_DICE[i], Sv_DICE[i], Cn_DICE[i], cn_DICE[i], Ut_DICE[i], ut_DICE[i]] = LEVii

        # Welfare covers periods 0..T-2 only, matching the objective fDICE minimises.
        if i < T - 1:
            W = W + Ut_DICE[i] * tstep
    return -W,Qt_DICE,Cn_DICE,Sv_DICE,Dm_DICE,al2,K_DICE,Ut_DICE

def fDICECC(v):
    """Optimisation objective with climate impacts on TFP: negative discounted welfare.

    Simulates periods 0..T-2 with ``stateCC``. As a side effect it writes the
    results into the module-level buffers (K_DICE[1:], Yt_DICE, ..., ut_DICE).

    Parameters
    ----------
    v : sequence of float
        Candidate savings rates. Only v[1] .. v[T-11] are read; period 0 and
        the last 10 periods use ActS0.

    Returns
    -------
    float
        -W, where W is the sum of Ut_DICE[t] * tstep over t = 0..T-2.
    """
    welfare = 0
    # Ten leading zeros stand in for temperatures before period 0.
    padded_temps = [0] * 10 + Tat_DICE
    for t in range(T - 1):
        # The ten temperatures preceding period t, oldest first.
        window = padded_temps[t:t + 10]
        rate = ActS0 if (t == 0 or t >= T - 10) else v[t]

        next_state, levels, _ = stateCC([K_DICE[t], window], [al[t], l[t], rr[t]], [rate])
        K_DICE[t + 1] = next_state[0]
        (Yt_DICE[t], Qt_DICE[t], Dm_DICE[t], Dt_DICE[t], Sv_DICE[t],
         Cn_DICE[t], cn_DICE[t], Ut_DICE[t], ut_DICE[t]) = levels

        welfare += Ut_DICE[t] * tstep
    return -welfare
    
def fDICE2CC(v):
    """Simulate savings path *v* with climate impacts on TFP and return the results.

    Uses the same recursion, savings rule and welfare sum as ``fDICECC``. It also
    simulates the final period T-1, so every reported series is filled: without
    that step, index T-1 would keep placeholder values (zeros, or q0 for Yt_DICE).
    The final period's utility is NOT added to W, so -W equals ``fDICECC(v)``.

    Parameters
    ----------
    v : sequence of float
        Savings rates. Only v[1] .. v[T-11] are read; period 0 and the last 10
        periods use ActS0.

    Returns
    -------
    tuple
        (-W, Qt_DICE, Cn_DICE, Sv_DICE, Dm_DICE, al2, K_DICE, Ut_DICE). All lists
        except al2 are the module-level buffers themselves, not copies, so a later
        fDICE* call overwrites them. al2 is a new list of post-impact TFP (A2).
    """
    W = 0
    al2 = [0]*T                                  #TFP after the temperature impact (A2), per period
    for i in range(T):
        # The ten temperatures preceding period i, oldest first (zero-padded at the start).
        if i < 10:
            Tat_DICE_temp = [0]*(10-i) + Tat_DICE[0:i]
        else:
            Tat_DICE_temp = Tat_DICE[i-10:i]
        STi = [K_DICE[i], Tat_DICE_temp]
        EXi = [al[i], l[i], rr[i]]
        # Period 0 and the last 10 periods use the fixed savings rate ActS0.
        if i >= T-10 or i == 0:
            Acti = [ActS0]
        else:
            Acti = [v[i]]

        ( STii, LEVii, STGEii ) = stateCC( STi, EXi, Acti )
        # The last period has no successor capital slot.
        if i + 1 < T:
            K_DICE[i + 1] = STii[0]
        al2[i] = STGEii[0]
        [Yt_DICE[i], Qt_DICE[i], Dm_DICE[i], Dt_DICE[i], Sv_DICE[i], Cn_DICE[i], cn_DICE[i], Ut_DICE[i], ut_DICE[i]] = LEVii

        # Welfare covers periods 0..T-2 only, matching the objective fDICECC minimises.
        if i < T - 1:
            W = W + Ut_DICE[i] * tstep
    return -W,Qt_DICE,Cn_DICE,Sv_DICE,Dm_DICE,al2,K_DICE,Ut_DICE

# ======================================== Optimization Algorithm (DICE model) ======================================== #

# Initial guess: the long-run optimal savings rate in every period.
x0 = T * [optlrsav]
# == bounds == #
# Savings rates lie in [0, 1]; period 0 is pinned to ActS0. The welfare functions
# also ignore v[0] and the last 10 entries, using ActS0 there.
bnds_DICE = T * [(0.0, 1.0)]
bnds_DICE[0] = (ActS0, ActS0)

# == optimization == #
ftol = 1e-12
eps = 1e-6
maxiter = 10000

# -- Scenario with climate impacts on TFP (stateCC) -- #
resCC = minimize(fDICECC, x0, method='SLSQP', bounds=bnds_DICE, options={'ftol': ftol, 'eps': eps, 'disp': True, 'maxiter': maxiter})
resDICECC = resCC.x

# WDiceCC is the NEGATIVE welfare. fDICE2CC returns the module-level buffers
# themselves, so snapshot them before the next optimisation overwrites them.
WDiceCC,YDiceCC,CDiceCC,S_DICECC,D_DICECC,al2CC,K_DICECC,Ut_DICECC = fDICE2CC(resDICECC)

YDiceCC = YDiceCC.copy()
CDiceCC = CDiceCC.copy()
S_DICECC = S_DICECC.copy()
D_DICECC = D_DICECC.copy()
K_DICECC = K_DICECC.copy()
Ut_DICECC = Ut_DICECC.copy()

# -- Scenario without climate impacts (state) -- #
res = minimize(fDICE, x0, method='SLSQP', bounds=bnds_DICE, options={'ftol': ftol, 'eps': eps, 'disp': True, 'maxiter': maxiter})
resDICEnoCC = res.x

# Snapshot these too, so any later fDICE* call cannot silently change them.
WDiceNoCC,YDiceNoCC,CDiceNoCC,S_DICENoCC,D_DICENoCC,al2NoCC,K_DICENoCC,Ut_DICENoCC = fDICE2(resDICEnoCC)

YDiceNoCC = YDiceNoCC.copy()
CDiceNoCC = CDiceNoCC.copy()
S_DICENoCC = S_DICENoCC.copy()
D_DICENoCC = D_DICENoCC.copy()
K_DICENoCC = K_DICENoCC.copy()
Ut_DICENoCC = Ut_DICENoCC.copy()


# One x-value per model period (tstep years apart), starting in 2024; with the
# default tstep = 1 and T = 127 this is 2024..2150.
years = [2024 + tstep * t for t in range(T)]
colors=['#377eb8']
# Shared line styling for every series below.
line_kw = {'marker': '', 'linestyle': '-', 'linewidth': 3}


def _pct_diff(with_cc, without_cc):
    """Return 100 * (with_cc - without_cc) / without_cc element-wise as a NumPy array."""
    base = np.array(without_cc)
    return 100 * ((np.array(with_cc) - base) / base)


# --- Figure 1: levels; red = with climate change (CC), blue = without --- #
fig, axes = plt.subplots(2, figsize=(27, 15))
# Output
axes[0].set_title('Output')
axes[0].set_xlabel('Year')
axes[0].set_ylabel('Output')
axes[0].plot(years, YDiceCC, color='r', label='With climate change', **line_kw)
axes[0].plot(years, YDiceNoCC, color='b', label='No climate change', **line_kw)
axes[0].set_xlim(years[0], years[-1])
axes[0].legend()
# Consumption (no-CC series in blue so it is distinguishable from the CC series)
axes[1].set_title('Consumption')
axes[1].set_xlabel('Year')
axes[1].set_ylabel('Consumption')
axes[1].plot(years, CDiceCC, color='r', label='With climate change', **line_kw)
axes[1].plot(years, CDiceNoCC, color='b', label='No climate change', **line_kw)
axes[1].set_xlim(years[0], years[-1])
axes[1].legend()


# --- Figure 2: % difference, with vs. without climate change --- #
fig, axes = plt.subplots(2, figsize=(27, 15))
# Output
axes[0].set_title('Output')
axes[0].set_xlabel('Year')
axes[0].set_ylabel('Output (% difference)')
axes[0].plot(years, _pct_diff(YDiceCC, YDiceNoCC), color='r', **line_kw)
axes[0].set_xlim(years[0], years[-1])
# Capital (this panel plots the capital stock, so it is labelled accordingly)
axes[1].set_title('Capital')
axes[1].set_xlabel('Year')
axes[1].set_ylabel('Capital (% difference)')
axes[1].plot(years, _pct_diff(K_DICECC, K_DICENoCC), color='r', **line_kw)
axes[1].set_xlim(years[0], years[-1])


# --- Figure 3: temperature path and % differences --- #
fig, axes = plt.subplots(2, 3, figsize=(27, 15))
# Temperature (the prescribed Tat_DICE path)
axes[0, 0].set_title('Temperature')
axes[0, 0].set_xlabel('Year')
axes[0, 0].set_ylabel('Temperature change (C)')
axes[0, 0].plot(years, Tat_DICE, color='r', **line_kw)
axes[0, 0].set_xlim(years[0], years[-1])
# Output
axes[0, 1].set_title('Output')
axes[0, 1].set_xlabel('Year')
axes[0, 1].set_ylabel('Output (% difference)')
axes[0, 1].plot(years, _pct_diff(YDiceCC, YDiceNoCC), color='r', **line_kw)
axes[0, 1].set_xlim(years[0], years[-1])
# Capital
axes[0, 2].set_title('Capital')
axes[0, 2].set_xlabel('Year')
axes[0, 2].set_ylabel('Capital (% difference)')
axes[0, 2].plot(years, _pct_diff(K_DICECC, K_DICENoCC), color='r', **line_kw)
axes[0, 2].set_xlim(years[0], years[-1])
# Consumption
axes[1, 0].set_title('Consumption')
axes[1, 0].set_xlabel('Year')
axes[1, 0].set_ylabel('Consumption (% difference)')
axes[1, 0].plot(years, _pct_diff(CDiceCC, CDiceNoCC), color='r', **line_kw)
axes[1, 0].set_xlim(years[0], years[-1])
# Welfare (discounted period utility; % changes are unstable if utility is near zero)
axes[1, 1].set_title('Welfare')
axes[1, 1].set_xlabel('Year')
axes[1, 1].set_ylabel('Welfare (% difference)')
axes[1, 1].plot(years, _pct_diff(Ut_DICECC, Ut_DICENoCC), color='r', **line_kw)
axes[1, 1].set_xlim(years[0], years[-1])
# The sixth panel is unused; hide it rather than show empty axes.
axes[1, 2].axis('off')

